iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 20
0
AI & Data

人工智慧(RL系列) 完爆遊戲30天系列 第 20

Day20 Dino train(下)

  • 分享至 

  • xImage
  •  

大部分內容都講完哩,今天來講Q訓練部分來做收尾~

訓練參數宣告

超過observe,就會開始訓練模型。

if t > OBSERVE:
    minibatch = random.sample(D, BATCH) # 從replay memory隨機抽取資料
    inputs = np.zeros((BATCH, s_t.shape[1], s_t.shape[2], s_t.shape[3])) #batch_size, 80, 80, 4
    targets = np.zeros((inputs.shape[0], ACTIONS)) # batch_size, actions

實作Q現實

根據batch大小,逐一實作Q現實。

for i in range(0, len(minibatch)):
    state_t = minibatch[i][0] # 圖像
    action_t = minibatch[i][1] # 執行的動作index
    reward_t = minibatch[i][2] # 輸入action後output的reward
    state_t1 = minibatch[i][3] # 輸入action後output的下個state
    terminal = minibatch[i][4] # 輸入action後output回報遊戲有無結束
    inputs[i:i + 1] = state_t # state塞回train要用的array
    targets[i] = model.predict(state_t) # Q估計
    Q_sa = model.predict(state_t1) # 下一步的Q估計
    if terminal:
        targets[i, action_t] = reward_t # 如果是最後步,這邊就只有reward
    else:
        targets[i, action_t] = reward_t + GAMMA * np.max(Q_sa) # Q現實

DQN訓練

loss += model.train_on_batch(inputs, targets)
loss_df.loc[len(loss_df)] = loss # 紀錄loss
q_values_df.loc[len(q_values_df)] = np.max(Q_sa) # 紀錄q_value

紀錄次數

s_t = initial_state if terminal else s_t1 
t = t + 1

如果訓練次數能被1000整除,則遊戲暫停,儲存資料。會直接暫停是因為資料IO要時間,怕影響到採樣。

if t % 1000 == 0:
    print("Now we save model")
    game_state._game.pause() #pause game while saving to filesystem
    model.save_weights("model.h5", overwrite=True)
    save_obj(D,"D") #saving episodes
    save_obj(t,"time") #caching time steps
    save_obj(epsilon,"epsilon") #cache epsilon to avoid repeated randomness in actions
    loss_df.to_csv("./objects/loss_df.csv",index=False)
    scores_df.to_csv("./objects/scores_df.csv",index=False)
    actions_df.to_csv("./objects/actions_df.csv",index=False)
    q_values_df.to_csv(q_value_file_path,index=False)
    with open("model.json", "w") as outfile:
        json.dump(model.to_json(), outfile)
    game_state._game.resume()

主程序

接下來把之前學的都寫進主程序,就可以開始訓練拉!

def playGame(observe=False):
    game = Game()
    dino = DinoAgent(game)
    game_state = Game_sate(dino,game)
    model = buildmodel()
    try:
        trainNetwork(model,game_state,observe=observe)
    except StopIteration:
        game.end()
playGame(observe=False)

程式碼實作

訓練小恐龍主程序unit5_dino

結語

小恐龍跳跳跳筆者是train了3天至1400多分,專案的原作者的demo則跑到4000多分,有個地方令我感到有些困擾,個人覺得在採樣的時候,樣本跟樣本時間間隔不一致,筆者猜因這個關係而導致收斂跟效果都沒想像中的好,畢竟你看小恐龍玩的規則很單純,但真的要走到高分的次數其實很少,就我觀察小恐龍有時會跳跳躍採到仙人掌或直接撞上,最好的表現是可以在快接近的時候跳躍。我的想法是小恐龍原本可收斂的很快,但因為採樣時間的不均一,導致效果差。

延伸下來我們可以思考幾個方式,例如兩隻程序一個是小恐龍的主程序走,負責採樣,另一支則輸出動作跟訓練,這當然會延伸其他問題例如兩支程序的速度不同,如何互相配合?類神經的速度如果能跟得上環境那還好說,但實際看起來卻並非如此。還有個最暴力方式就是不斷讓環境該停止就停,以此控制採樣跟執行action的間隔,兩個都有些想法,不過這可能要等之後有段時間再來實行了><

用11篇講解keras實作DQN,學到這邊同學有沒很有成就感呢?到這再玩其他強化學習專案就就可以很快上手囉!恭喜同學堅持到今天~接下來幾篇我們會講解進階的DQN方法以及自建環境,大家明天見拉~


上一篇
Day19 Dion train(中)
下一篇
Day21 Double DQN
系列文
人工智慧(RL系列) 完爆遊戲30天30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言